{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 10\n",
    "n_test_batches = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 11 - Secure Deep Learning Classification \n",
    "\n",
    "\n",
    "\n",
    "## Your data matters, your model too\n",
    "\n",
    "Data is the driver behind Machine Learning. Organizations who create and collect data are able to build and train their own machine learning models. This allows them to offer the use of such models as a service (MLaaS) to outside organizations. This is useful as other organizations who might not be able to create these models themselves but who still would like to use this model to make predictions on their own data. \n",
    "\n",
    "However, a model hosted in the cloud still presents a privacy/IP issue. In order for external organizations to use it - they must either upload their input data (such as images to be classified) or download the model. Uploading input data can be problematic from a privacy perspective, but downloading the model might not be an option if the organization who created/owns the model is worried about losing their IP.\n",
    "\n",
    "\n",
    "## Computing over encrypted data\n",
    "\n",
    "In this context, one potential solution is to encrypt both the model and the data in a way which allows one organization to use a model owned by another organization without either disclosing their IP to one another. Several encryption schemes exist that allow for computation over encrypted data, among which Secure Multi-Party Computation (SMPC), Homomorphic Encryption (FHE/SHE) and Functional Encryption (FE) are the most well known types. We will focus here on Secure Multi-Party Computation ([introduced in detail here in tutorial 5](https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%205%20-%20Intro%20to%20Encrypted%20Programs.ipynb)) which consists of private additive sharing. It relies on crypto protocols such as SecureNN and SPDZ, the details of which are given [in this excellent blog post](https://mortendahl.github.io/2017/09/19/private-image-analysis-with-mpc/). \n",
    "\n",
    "These protocols achieve remarkable performances over encrypted data, and over the past few months we have been working to make these protocols easy to use. Specifically, we're building tools to allow you to use these protocols without having to re-implement the protocol yourself (or even necessarily know the cryptography behind how it works). Let's jump right in.\n",
    "\n",
    "## Set up\n",
    "\n",
    "The exact setting in this tutorial is the following: consider that you are the server and you have some data. First, you define and train a model with this private training data. Then, you get in touch with a client who holds some of their own data who would like to access your model to make some predictions. \n",
    "\n",
    "You encrypt your model (a neural network). The client encrypts their data. You both then use these two encrypted assets to use the model to classify the data. Finally, the result of the prediction is sent back to the client in an encrypted way so that the server (_i.e._ you) learns nothing about the client's data (you learn neither the inputs or the prediction).\n",
    "\n",
    "Ideally we would additively share the `client`'s input between itself and the `server` and vice versa for the model. For the sake of simplicity, the shares will be held by two other workers `alice` and `bob`. If you consider that alice is owned by the client and bob by the server, it's completely equivalent.\n",
    "\n",
    "The computation is secure in the honest-but-curious adversary model which is standard in [many MPC frameworks](https://arxiv.org/pdf/1801.03239.pdf).\n",
    "\n",
    "**We have now everything we need, let's get started!**\n",
    "\n",
    "\n",
    "Author:\n",
    "- Théo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel) · GitHub: [@LaRiffle](https://github.com/LaRiffle)\n",
    "\n",
    "**Let's get started!**"
   ]
  },
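  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of additive sharing more concrete, here is a minimal plain-Python sketch. It is not PySyft's implementation: the ring size `Q` and the helpers `share` and `reconstruct` are hypothetical names chosen for illustration. A secret is split into random-looking shares that sum to the secret modulo `Q`, and additions can be done by each share holder locally.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "Q = 2 ** 62  # ring size, an arbitrary choice for this sketch\n",
    "\n",
    "def share(secret, n_parties=2):\n",
    "    # All shares but the last are uniformly random\n",
    "    shares = [random.randrange(Q) for _ in range(n_parties - 1)]\n",
    "    # The last share makes the sum (mod Q) equal to the secret\n",
    "    shares.append((secret - sum(shares)) % Q)\n",
    "    return shares\n",
    "\n",
    "def reconstruct(shares):\n",
    "    # Only the sum of all the shares reveals the secret\n",
    "    return sum(shares) % Q\n",
    "\n",
    "# One share would go to alice, the other to bob\n",
    "alice_share, bob_share = share(42)\n",
    "assert reconstruct([alice_share, bob_share]) == 42\n",
    "\n",
    "# Addition is local: each holder adds the shares it owns\n",
    "x_shares = share(5)\n",
    "y_shares = share(7)\n",
    "z_shares = [(x + y) % Q for x, y in zip(x_shares, y_shares)]\n",
    "assert reconstruct(z_shares) == 12\n",
    "```\n",
    "\n",
    "Taken alone, each share is a uniformly random number, which is why neither `alice` nor `bob` learns anything from the share it holds."
   ]
  },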
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports and model specifications"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "   We also need to execute commands specific to importing/starting PySyft. We create a few workers (named `client`, `bob`, and `alice`). Lastly, we define the `crypto_provider` who gives all the crypto primitives we may need ([See our tutorial on SMPC for more details](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2009%20-%20Intro%20to%20Encrypted%20Programs.ipynb)."
   ]
  },
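  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of what a crypto primitive can be: in SPDZ-style protocols, multiplying two secret-shared values typically uses a precomputed multiplication triple (often called a Beaver triple) handed out by a third party. The plain-Python sketch below shows the idea under that assumption. It is not PySyft's code, and `Q`, `share`, `reconstruct`, `make_triple` and `beaver_mul` are hypothetical names.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "Q = 2 ** 62  # ring size, an arbitrary choice for this sketch\n",
    "\n",
    "def share(secret):\n",
    "    s0 = random.randrange(Q)\n",
    "    return [s0, (secret - s0) % Q]\n",
    "\n",
    "def reconstruct(shares):\n",
    "    return sum(shares) % Q\n",
    "\n",
    "def make_triple():\n",
    "    # Role played by the crypto provider in this sketch:\n",
    "    # random a and b with c = a * b, all handed out as shares\n",
    "    a = random.randrange(Q)\n",
    "    b = random.randrange(Q)\n",
    "    return share(a), share(b), share((a * b) % Q)\n",
    "\n",
    "def beaver_mul(x_sh, y_sh):\n",
    "    a_sh, b_sh, c_sh = make_triple()\n",
    "    # The parties open the masked values e = x - a and f = y - b\n",
    "    e = reconstruct([(x - a) % Q for x, a in zip(x_sh, a_sh)])\n",
    "    f = reconstruct([(y - b) % Q for y, b in zip(y_sh, b_sh)])\n",
    "    # x * y = c + e * b + f * a + e * f, computed share by share\n",
    "    z_sh = [(c + e * b + f * a) % Q for a, b, c in zip(a_sh, b_sh, c_sh)]\n",
    "    # The public term e * f is added by one party only\n",
    "    z_sh[0] = (z_sh[0] + e * f) % Q\n",
    "    return z_sh\n",
    "\n",
    "assert reconstruct(beaver_mul(share(6), share(7))) == 42\n",
    "```\n",
    "\n",
    "The opened values `e` and `f` are masked by the random `a` and `b`, so they reveal nothing about `x` and `y` as long as the triple itself stays secret."
   ]
  },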
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy\n",
    "hook = sy.TorchHook(torch) \n",
    "client = sy.VirtualWorker(hook, id=\"client\")\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n",
    "crypto_provider = sy.VirtualWorker(hook, id=\"crypto_provider\") "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define the setting of the learning task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size = 64\n",
    "        self.test_batch_size = 50\n",
    "        self.epochs = epochs\n",
    "        self.lr = 0.001\n",
    "        self.log_interval = 100\n",
    "\n",
    "args = Arguments()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data loading and sending to workers\n",
    "\n",
    "In our setting, we assume that the server has access to some data to first train its model. Here is the MNIST training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=args.batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Second, the client has some data and would like to have predictions on it using the server's model. This client encrypts its data by sharing it additively across two workers `alice` and `bob`.\n",
    "> SMPC uses crypto protocols which require to work on integers. We leverage here the PySyft tensor abstraction to convert PyTorch Float tensors into Fixed Precision Tensors using `.fix_precision()`. For example 0.123 with precision 2 does a rounding at the 2nd decimal digit so the number stored is the integer 12. \n"
   ]
  },
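  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fixed precision encoding described above can be sketched in a few lines of plain Python. This is an illustration only, assuming a base-10 encoding: `encode`, `decode` and `fixed_mul` are hypothetical helpers, not PySyft functions.\n",
    "\n",
    "```python\n",
    "def encode(value, precision=2):\n",
    "    # Scale by 10 ** precision and round to the nearest integer\n",
    "    return int(round(value * 10 ** precision))\n",
    "\n",
    "def decode(encoded, precision=2):\n",
    "    # Undo the scaling to get a float back\n",
    "    return encoded / 10 ** precision\n",
    "\n",
    "def fixed_mul(a, b, precision=2):\n",
    "    # The raw product carries 2 * precision fractional digits,\n",
    "    # so it is rescaled once to stay at the same precision\n",
    "    return (a * b) // 10 ** precision\n",
    "\n",
    "stored = encode(0.123)      # 12\n",
    "recovered = decode(stored)  # 0.12\n",
    "product = decode(fixed_mul(encode(1.5), encode(2.0)))  # 3.0\n",
    "```\n",
    "\n",
    "The digits beyond the chosen precision are lost (0.123 comes back as 0.12), which is the price paid for working on integers."
   ]
  },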
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=args.test_batch_size, shuffle=True)\n",
    "\n",
    "private_test_loader = []\n",
    "for data, target in test_loader:\n",
    "    private_test_loader.append((\n",
    "        data.fix_precision().share(alice, bob, crypto_provider=crypto_provider),\n",
    "        target.fix_precision().share(alice, bob, crypto_provider=crypto_provider)\n",
    "    ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Feed Forward Neural Network specification\n",
    "Here is the network specification used by the server"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(784, 500)\n",
    "        self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 784)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Launch the training\n",
    "The training is done locally so this is pure local PyTorch training, nothing special here!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(args, model, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        output = F.log_softmax(output, dim=1)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * args.batch_size, len(train_loader) * args.batch_size,\n",
    "                100. * batch_idx / len(train_loader), loss.item()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(args, model, train_loader, optimizer, epoch)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(args, model, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            output = model(data)\n",
    "            output = F.log_softmax(output, dim=1)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n",
    "            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability \n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        100. * correct / len(test_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test(args, model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our model is now trained and ready to be provided as a service!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Secure evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, as the server, we send the model to the workers holding the data. Because the model is sensitive information (you've spent time optimizing it!), you don't want to disclose its weights so you secret share the model just like we did with the dataset earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fix_precision().share(alice, bob, crypto_provider=crypto_provider)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This test function performs the encrypted evaluation. The model weights, the data inputs, the prediction and the target used for scoring are encrypted!\n",
    "\n",
    "However, the syntax is very similar to pure PyTorch testing of a model, isn't it nice?!\n",
    "\n",
    "The only thing we decrypt from the server side is the final score at the end to verify predictions were on average good."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(args, model, test_loader):\n",
    "    model.eval()\n",
    "    n_correct_priv = 0\n",
    "    n_total = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader[:n_test_batches]:\n",
    "            output = model(data)\n",
    "            pred = output.argmax(dim=1) \n",
    "            n_correct_priv += pred.eq(target.view_as(pred)).sum()\n",
    "            n_total += args.test_batch_size\n",
    "# This 'test' function performs the encrypted evaluation. The model weights, the data inputs, the prediction and the target used for scoring are all encrypted!\n",
    "\n",
    "# However as you can observe, the syntax is very similar to normal PyTorch testing! Nice!\n",
    "\n",
    "# The only thing we decrypt from the server side is the final score at the end of our 200 items batches to verify predictions were on average good.      \n",
    "            n_correct = n_correct_priv.copy().get().float_precision().long().item()\n",
    "    \n",
    "            print('Test set: Accuracy: {}/{} ({:.0f}%)'.format(\n",
    "                n_correct, n_total,\n",
    "                100. * n_correct / n_total))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test(args, model, private_test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Et voilà! Here you are, you have learned how to do end to end secure predictions: the weights of the server's model have not leaked to the client and the server has no information about the data input nor the classification output!\n",
    "\n",
    "Regarding performance, classifying one image takes **less than 0.1 second**, approximately **33ms** on my laptop (2,7 GHz Intel Core i7, 16GB RAM). However, this is using very fast communication (all the workers are on my local machine). Performance will vary depending on how fast different workers can talk to each other."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "You have seen how easy it is to leverage PyTorch and PySyft to perform practical Secure Machine Learning and protect users data, without having to be a crypto expert!\n",
    "\n",
    "More on this topic will come soon, including convolutional layers to properly benchmark PySyft performance with respect to other libraries, as well as private encrypted training of neural networks, which is needed when a organisation resorts to external sensitive data to train its own model. Stay tuned!\n",
    "\n",
    "If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways! \n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Pick our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Checkout the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! \n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to PySyft GitHub Issues page and search for issues marked `Good First Issue`.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
